import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
import torch.nn as nn
import torch, os
import torch.nn.functional as F

from loss import weighted_mae
from utils import add_parser, reserve_schedule_sampling_exp, schedule_sampling


class PredRNN(LightningModule):
    """PredRNN-v2 style forecaster built around a single stateful ``PredRNN_encoder_v2`` stack.

    Frames are folded with ``nn.PixelUnshuffle`` before entering the recurrent stack and unfolded
    with ``nn.PixelShuffle`` on the way out. The same encoder stack is used for both the
    input (warm-up) phase and the target (forecast) phase; its recurrent state carries over.
    """

    def __init__(self, downscale_factor, num_hidden, loss_fx, weights_prec, thresholds_prec, width, filter_size,
                 stride, layer_norm, input_channels, predict_channels, input_length, target_length, learning_rate,
                 reverse_scheduled_sampling, r_sampling_step_1, r_sampling_step_2, r_exp_alpha, scheduled_sampling,
                 sampling_stop_iter, sampling_changing_rate, param_a, param_b, param_c, param_d, visual_train_steps,
                 visual_val_steps, visual_c1_vmin, visual_c1_vmax, visual_c1_unorm, train_log_steps, val_log_steps,
                 test_save_path, **kwargs):
        """Build the model. All arguments are recorded via ``save_hyperparameters``.

        :param downscale_factor: int, factor for ``nn.PixelUnshuffle``/``nn.PixelShuffle``; channel counts are
            multiplied by ``downscale_factor ** 2`` and the spatial size is divided by ``downscale_factor``.
        :param num_hidden: list[int], one entry per recurrent layer giving that layer's hidden channels.
        :param loss_fx: recorded in hparams only; the loss actually used is
            ``weighted_mae(weights_prec, thresholds_prec)``.
        :param weights_prec: list[float], weights forwarded to ``weighted_mae``.
        :param thresholds_prec: list[float], thresholds forwarded to ``weighted_mae``.
        :param width: int, spatial size of the raw frame before pixel-unshuffling.
        :param filter_size: int, convolution kernel size inside the ST-LSTM cells.
        :param stride: int, convolution stride inside the ST-LSTM cells.
        :param layer_norm: truthy to add LayerNorm after each cell convolution (default 0 = off).
        :param input_channels: int, channels of the raw input frames (before unshuffling).
        :param predict_channels: int, channels of the raw predicted frames (before unshuffling).
        :param input_length: int, number of warm-up (input-phase) steps.
        :param target_length: int, number of predicted frames.
        :param learning_rate: float, Adam learning rate.
        :param reverse_scheduled_sampling: 1 to use reverse scheduled sampling during training.
        :param r_sampling_step_1: step count of the first reverse-scheduled-sampling phase.
        :param r_sampling_step_2: step count of the second reverse-scheduled-sampling phase.
        :param r_exp_alpha: decay parameter of reverse scheduled sampling.
        :param scheduled_sampling: 1 to use (classic) scheduled sampling when reverse sampling is off.
        :param sampling_stop_iter: step at which scheduled sampling stops.
        :param sampling_changing_rate: decay rate of scheduled sampling.
        :param param_a: weight of the input-phase reconstruction loss.
        :param param_b: weight of the target-phase prediction loss.
        :param param_c: weight of the encoder-phase decoupling loss.
        :param param_d: weight of the target-phase decoupling loss.
        :param visual_train_steps: per the original docs, iterations between image logs during training
            (not read in this class).
        :param visual_val_steps: per the original docs, iterations between image logs during validation
            (not read in this class).
        :param visual_c1_vmin: stored in hparams only (not read in this class).
        :param visual_c1_vmax: stored in hparams only (not read in this class).
        :param visual_c1_unorm: stored in hparams only (not read in this class).
        :param train_log_steps: per the original docs, iterations between scalar logs (not read in this class).
        :param val_log_steps: per the original docs, iterations between scalar logs (not read in this class).
        :param test_save_path: stored in hparams only (not read in this class).
        :param kwargs: extra arguments, captured by ``save_hyperparameters``.
        """

        super().__init__()
        self.save_hyperparameters()

        # The ``loss_fx`` argument is intentionally not used; the weighted MAE is always applied.
        self.loss_fx = weighted_mae(self.hparams.weights_prec,
                                    self.hparams.thresholds_prec,)

        self.num_layers = len(num_hidden)
        adapter_num_hidden = num_hidden[0]

        self.pixel_unshuffle = nn.PixelUnshuffle(self.hparams.downscale_factor)
        self.pixel_shuffle = nn.PixelShuffle(self.hparams.downscale_factor)

        self.channels_in = input_channels * self.hparams.downscale_factor ** 2
        self.channels_out = predict_channels * self.hparams.downscale_factor ** 2
        patched_width = width // downscale_factor
        self.encoder = PredRNN_encoder_v2(num_hidden, patched_width, filter_size, stride, layer_norm,
                                          adapter_num_hidden, self.num_layers, self.channels_in,
                                          self.channels_out)

    def forward(self, frames_pwafs_input, mask_input, frames_cmpas_output, mask_output, **kwargs):
        """Run the input (warm-up) phase followed by the target (forecast) phase.

        :param frames_pwafs_input: tensor [batch, length, channel, height, width], the conditioning frames.
            It is indexed up to ``input_length + target_length - 1``, so it needs at least
            ``input_length + target_length`` frames.
        :param mask_input: tensor broadcastable to [batch, input_length-1, channel, height, width]; mask for the
            input phase (1 keeps the observed station channels, 0 feeds back the model's previous output).
        :param frames_cmpas_output: tensor [batch, length, channel, height, width], the target frames; during the
            target phase frame ``input_length + t`` is mixed with the model output through ``mask_output``.
        :param mask_output: tensor broadcastable to [batch, target_length-1, channel, height, width]; mask for the
            target phase (all zeros means fully free-running prediction).
        :param kwargs: ignored.
        :return: ``(encoder_frames, decoder_frames, decouple_loss_encoder_sum, decouple_loss_decoder_sum)``.
            ``encoder_frames`` has ``input_length - 1`` frames and ``decoder_frames`` has ``target_length`` frames,
            both pixel-shuffled back to full resolution.
        """

        frames_cmpas_output = self.pixel_unshuffle(frames_cmpas_output)
        frames_pwafs_input = self.pixel_unshuffle(frames_pwafs_input)

        # [batch, length, channel, height, width] (after unshuffling)
        batch, _, _, height, width = frames_pwafs_input.shape
        encoder_frames = []
        decoder_frames = []
        h_t = []
        c_t = []
        for i in range(self.num_layers):
            # One zero tensor seeds both h and c; safe because the cells return new tensors
            # instead of modifying their state in place.
            zeros = torch.zeros([batch, self.hparams.num_hidden[i], height, width], device=self.device)
            h_t.append(zeros)
            c_t.append(zeros)
        memory = torch.zeros([batch, self.hparams.num_hidden[0], height, width], device=self.device)

        # NOTE(review): the station slice width 9 is hard-coded; it appears to assume
        # predict_channels * downscale_factor ** 2 == 9 (e.g. one channel, factor 3) — confirm.
        decouple_loss_encoder_sum = 0
        for t in range(self.hparams.input_length):
            if t == 0:
                input_encoder = torch.cat([frames_pwafs_input[:, 0, 0:9, ...], frames_pwafs_input[:, 1]], dim=1)
            else:
                # The previous step's output is collected, then optionally replaced by observed stations.
                encoder_frames.append(x_gen_encoder)
                input_encoder_stations = mask_input[:, t - 1] * frames_pwafs_input[:, t, 0:9, ...] + (
                        1 - mask_input[:, t - 1]) * x_gen_encoder[:, :9, ...]
                input_encoder = torch.cat([input_encoder_stations, frames_pwafs_input[:, t + 1]], dim=1)
            h_t, c_t, memory, x_gen_encoder, decouple_loss_encoder = self.encoder(input_encoder, t == 0,
                                                                                 h_t, c_t, memory)
            decouple_loss_encoder_sum += decouple_loss_encoder

        # The last input-phase output is the first forecast frame.
        decoder_frames.append(x_gen_encoder)
        decouple_loss_decoder_sum = 0
        for t in range(self.hparams.target_length - 1):
            input_decoder_stations = mask_output[:, t] * frames_cmpas_output[:, self.hparams.input_length + t] \
                                     + (1 - mask_output[:, t]) * x_gen_encoder[:, :9, ...]
            input_decoder = torch.cat([input_decoder_stations,
                                       frames_pwafs_input[:, self.hparams.input_length + t + 1]], dim=1)
            # first_timestep=False: continue from the recurrent state held inside the encoder.
            _, _, _, x_gen_encoder, decouple_loss_decoder = self.encoder(input_decoder, False)

            decoder_frames.append(x_gen_encoder)
            decouple_loss_decoder_sum += decouple_loss_decoder

        # NOTE: torch.stack needs a non-empty list, so input_length must be >= 2.
        encoder_frames = torch.stack(encoder_frames, dim=1)
        decoder_frames = torch.stack(decoder_frames, dim=1)

        encoder_frames = self.pixel_shuffle(encoder_frames)
        decoder_frames = self.pixel_shuffle(decoder_frames)

        return encoder_frames, decoder_frames, decouple_loss_encoder_sum, decouple_loss_decoder_sum

    def training_step(self, batch, batch_idx):
        """Pick the sampling masks for this step, run forward, and return the weighted total loss."""
        input_tensor_pwafs, target_tensor_cmpas = batch
        batch_size = input_tensor_pwafs.shape[0]

        if self.hparams.reverse_scheduled_sampling == 1:
            input_flag, target_flag, _r_eta, _eta = reserve_schedule_sampling_exp(
                self.global_step, self.hparams.r_sampling_step_1, self.hparams.r_sampling_step_2,
                self.hparams.r_exp_alpha, batch_size, self.hparams.input_length, self.hparams.target_length,
                self.device)
        elif self.hparams.scheduled_sampling == 1:
            input_flag, target_flag, _r_eta, _eta = schedule_sampling(
                self.global_step, self.hparams.sampling_stop_iter, self.hparams.sampling_changing_rate,
                batch_size, self.hparams.input_length, self.hparams.target_length, self.device)
        else:
            # Input phase always uses observed stations; target phase always feeds back predictions.
            input_flag = torch.ones((1, self.hparams.input_length - 1, 1, 1, 1), device=self.device)
            target_flag = torch.zeros((1, self.hparams.target_length - 1, 1, 1, 1), device=self.device)

        encoder_frames, decoder_frames, decouple_loss_encoder, decouple_loss_decoder = self(
            input_tensor_pwafs, input_flag, target_tensor_cmpas, target_flag)

        input_loss = self.loss_fx(encoder_frames, target_tensor_cmpas[:, 1:self.hparams.input_length, ...])
        target_loss = self.loss_fx(decoder_frames, target_tensor_cmpas[:, self.hparams.input_length:])

        loss = self.hparams.param_a * input_loss + \
               self.hparams.param_b * target_loss + \
               self.hparams.param_c * decouple_loss_encoder + \
               self.hparams.param_d * decouple_loss_decoder

        return loss

    def validation_step(self, batch, batch_idx):
        """Free-running validation: compute the weighted MAE of the clipped forecast and log it per epoch."""
        input_tensor_pwafs, target_tensor_cmpas = batch

        input_flag = torch.ones(1, self.hparams.input_length - 1, 1, 1, 1, device=self.device)
        target_flag = torch.zeros(1, self.hparams.target_length - 1, 1, 1, 1, device=self.device)
        _, decoder_frames, _, _ = self(input_tensor_pwafs, input_flag, target_tensor_cmpas, target_flag)
        decoder_frames = torch.clip(decoder_frames, 0, 1)
        valid_loss_fx = self.loss_fx(decoder_frames, target_tensor_cmpas[:, self.hparams.input_length:, ...])
        metrics_pred = {'valid_loss_fx': valid_loss_fx.item()}
        self.log_dict(metrics_pred, on_step=False, on_epoch=True, prog_bar=True)
        return metrics_pred

    def configure_optimizers(self):
        """Plain Adam over all parameters at ``hparams.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register this model's command-line options on ``parent_parser`` and return it."""
        import argparse

        def _str2bool(value):
            # argparse with ``type=bool`` treats any non-empty string (including "0"/"False") as True,
            # so parse the common spellings explicitly.
            if isinstance(value, bool):
                return value
            text = str(value).strip().lower()
            if text in ('1', 'true', 't', 'yes', 'y', 'on'):
                return True
            if text in ('0', 'false', 'f', 'no', 'n', 'off'):
                return False
            raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

        parser = parent_parser.add_argument_group("PredRNN_v2")
        add_parser(parser)
        parser.add_argument('--num_hidden', nargs='+', type=int, default=[128, 128, 128, 128])
        parser.add_argument('--weights_prec', nargs='+', type=float, default=[1, 1, 2.5, 5, 10, 20])
        parser.add_argument('--thresholds_prec', nargs='+', type=float, default=[0.0083, 0.0167, 0.083, 0.167, 0.333])
        parser.add_argument('--filter_size', type=int, default=5)
        parser.add_argument('--stride', type=int, default=1)
        parser.add_argument('--layer_norm', type=_str2bool, default=0)
        parser.add_argument('--reverse_scheduled_sampling', type=int, default=1)
        parser.add_argument('--r_sampling_step_1', type=int, default=2500)
        parser.add_argument('--r_sampling_step_2', type=int, default=5000)
        parser.add_argument('--r_exp_alpha', type=float, default=500)
        parser.add_argument('--scheduled_sampling', type=int, default=0)
        parser.add_argument('--sampling_stop_iter', type=int, default=50000)
        parser.add_argument('--sampling_changing_rate', type=float, default=0.00002)
        parser.add_argument('--param_a', type=float, default=1)
        parser.add_argument('--param_b', type=float, default=2)
        parser.add_argument('--param_c', type=float, default=0.05)
        parser.add_argument('--param_d', type=float, default=0.05)
        parser.add_argument("--visual_c1_vmin", type=float, default=-2)
        parser.add_argument("--visual_c1_vmax", type=float, default=46)
        parser.add_argument("--visual_c1_unorm", type=float, default=60)
        return parent_parser

class PredRNN_encoder_v2(pl.LightningModule):
    """Stack of ``SpatioTemporalLSTMCell`` layers that keeps its recurrent state between calls.

    A sequence is started by calling ``forward`` with ``first_timestep=True`` and initial
    ``h_t``/``c_t``/``memory``; later calls with ``first_timestep=False`` continue from the stored state.

    Channel constraint: ``adapter_encoder`` (``adapter_num_hidden`` -> ``adapter_num_hidden``) is applied to
    every layer's ``delta_c``/``delta_m``, and the single ``memory`` tensor passes through every layer, so each
    ``num_hidden[i]`` must equal ``adapter_num_hidden``.
    """

    def __init__(self, num_hidden, patched_width, filter_size, stride, layer_norm,
                 adapter_num_hidden, num_layers, channels_in, channels_out):
        super().__init__()
        cell_list_encoder = []
        for i in range(num_layers):
            # The bottom layer sees the raw input; higher layers see the hidden state of the layer below.
            in_channel = channels_in if i == 0 else num_hidden[i - 1]
            cell_list_encoder.append(
                SpatioTemporalLSTMCell(in_channel, num_hidden[i], patched_width, filter_size,
                                       stride, layer_norm))

        self.cell_list_encoder = nn.ModuleList(cell_list_encoder)
        self.conv_last_encoder = nn.Conv2d(num_hidden[num_layers - 1], channels_out,
                                           kernel_size=1, stride=1, padding=0, bias=False)

        self.adapter_encoder = nn.Conv2d(adapter_num_hidden, adapter_num_hidden, 1, stride=1, padding=0, bias=False)
        # Recurrent state; populated on the first_timestep=True call.
        self.h_t = []
        self.c_t = []
        self.memory = None
        self.num_layers = num_layers
        self.num_hidden = num_hidden

    def forward(self, input_, first_timestep=False, h_t=None, c_t=None, memory=None):
        """Advance every layer by one time step.

        :param input_: input tensor for the bottom layer.
        :param first_timestep: if True, (re)initialise the stored state from ``h_t``, ``c_t`` and ``memory``;
            otherwise those arguments are ignored and the stored state is used.
        :param h_t: list of per-layer hidden states (used only when ``first_timestep`` is True).
        :param c_t: list of per-layer cell states (used only when ``first_timestep`` is True).
        :param memory: spatiotemporal memory tensor (used only when ``first_timestep`` is True).
        :return: ``(h_t, c_t, memory, x_gen, decouple_loss)``. ``x_gen`` is the 1x1-conv projection of the top
            hidden state, and ``decouple_loss`` is the mean over layers of ``|cos(delta_c, delta_m)|``.
        :raises RuntimeError: if called with ``first_timestep=False`` before any state was initialised.
        """

        if first_timestep:
            # Copy the lists so the per-layer updates below do not mutate the caller's lists.
            self.h_t = list(h_t)
            self.c_t = list(c_t)
            self.memory = memory
        elif not self.h_t or self.memory is None:
            raise RuntimeError("PredRNN_encoder_v2.forward called with first_timestep=False "
                               "before the recurrent state was initialised")

        decouple_loss = 0
        for i, cell in enumerate(self.cell_list_encoder):
            layer_input = input_ if i == 0 else self.h_t[i - 1]
            self.h_t[i], self.c_t[i], self.memory, delta_c, delta_m = cell(layer_input, self.h_t[i], self.c_t[i],
                                                                           self.memory)

            # Decoupling term: penalise alignment between the cell-state and memory increments,
            # compared per channel over flattened spatial positions.
            delta_c_list = F.normalize(self.adapter_encoder(delta_c).view(delta_c.shape[0], delta_c.shape[1], -1),
                                       dim=2)
            delta_m_list = F.normalize(self.adapter_encoder(delta_m).view(delta_m.shape[0], delta_m.shape[1], -1),
                                       dim=2)
            decouple_loss += torch.mean(torch.abs(torch.cosine_similarity(delta_c_list, delta_m_list, dim=2)))

        x_gen_encoder = self.conv_last_encoder(self.h_t[self.num_layers - 1])

        return self.h_t, self.c_t, self.memory, x_gen_encoder, decouple_loss / self.num_layers

class SpatioTemporalLSTMCell(pl.LightningModule):
    """A single spatiotemporal LSTM cell with a temporal state ``c`` and a spatiotemporal memory ``m``.

    ``forward`` returns the new hidden state, both new states, and the two increments
    (``delta_c``, ``delta_m``) that the caller uses for its decoupling loss.
    """

    def __init__(self, in_channel, num_hidden, width, filter_size, stride, layer_norm):
        super().__init__()

        self.num_hidden = num_hidden
        self.padding = filter_size // 2
        self._forget_bias = 1.0

        def gate_conv(c_in, c_out):
            # Bias-free convolution, optionally followed by LayerNorm over (channels, width, width).
            stages = [nn.Conv2d(c_in, c_out, kernel_size=filter_size, stride=stride,
                                padding=self.padding, bias=False)]
            if layer_norm:
                stages.append(nn.LayerNorm([c_out, width, width]))
            return nn.Sequential(*stages)

        # Gate groups: x -> (i, f, g, i', f', g', o), h -> (i, f, g, o), m -> (i', f', g').
        self.conv_x = gate_conv(in_channel, num_hidden * 7)
        self.conv_h = gate_conv(num_hidden, num_hidden * 4)
        self.conv_m = gate_conv(num_hidden, num_hidden * 3)
        self.conv_o = gate_conv(num_hidden * 2, num_hidden)
        self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x_t, h_t, c_t, m_t):
        """Advance the cell one step; returns ``(h_new, c_new, m_new, delta_c, delta_m)``."""
        width = self.num_hidden
        bias = self._forget_bias

        ix, fx, gx, ix2, fx2, gx2, ox = torch.split(self.conv_x(x_t), width, dim=1)
        ih, fh, gh, oh = torch.split(self.conv_h(h_t), width, dim=1)
        im, fm, gm = torch.split(self.conv_m(m_t), width, dim=1)

        # Temporal branch (driven by the previous hidden state).
        input_gate = torch.sigmoid(ix + ih)
        forget_gate = torch.sigmoid(fx + fh + bias)
        delta_c = input_gate * torch.tanh(gx + gh)
        c_new = forget_gate * c_t + delta_c

        # Spatiotemporal branch (driven by the memory).
        input_gate_m = torch.sigmoid(ix2 + im)
        forget_gate_m = torch.sigmoid(fx2 + fm + bias)
        delta_m = input_gate_m * torch.tanh(gx2 + gm)
        m_new = forget_gate_m * m_t + delta_m

        # Output gate and hidden state both read the concatenated (c, m) states.
        joint = torch.cat((c_new, m_new), 1)
        out_gate = torch.sigmoid(ox + oh + self.conv_o(joint))
        h_new = out_gate * torch.tanh(self.conv_last(joint))

        return h_new, c_new, m_new, delta_c, delta_m
